import dataclasses
import json
import logging
import pprint
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Union, cast, runtime_checkable

import humanfriendly
import pydantic
from pydantic import BaseModel
from tabulate import tabulate
from typing_extensions import Literal, Protocol

from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import mcps_from_mce
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.report_helpers import format_datetime_relative
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.autogenerated.lineage_helper import is_lineage_aspect
from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
from datahub.metadata.schema_classes import (
    MetadataChangeProposalClass,
    StatusClass,
    SubTypesClass,
    UpstreamLineageClass,
)
from datahub.utilities.file_backed_collections import FileBackedDict
from datahub.utilities.lossy_collections import LossyList
from datahub.utilities.urns.urn import guess_platform_name

logger = logging.getLogger(__name__)
# Severity levels accepted by report attributes; names mirror the stdlib logging levels.
LogLevel = Literal["ERROR", "WARNING", "INFO", "DEBUG"]


@runtime_checkable
class SupportsAsObj(Protocol):
    """Structural type for anything that can serialize itself into a plain dict."""

    def as_obj(self) -> dict: ...


@dataclass
class Report(SupportsAsObj):
    """Base class for ingestion reports.

    Subclasses add dataclass fields; :meth:`as_obj` serializes every public,
    non-null attribute into plain python objects suitable for logging or JSON.
    """

    def __post_init__(self) -> None:
        # Deliberately set in __post_init__ rather than declared as a dataclass
        # field, so subclasses' field ordering is unaffected.
        self.platform: Optional[str] = None

    def set_platform(self, platform: str) -> None:
        self.platform = platform

    def get_platform(self) -> Optional[str]:
        return self.platform

    @staticmethod
    def to_str(some_val: Any) -> str:
        """Stringify a value, using the member name (not repr) for enums."""
        if isinstance(some_val, Enum):
            return some_val.name
        else:
            return str(some_val)

    @staticmethod
    def to_pure_python_obj(some_val: Any) -> Any:
        """A cheap way to generate a dictionary.

        Recursively converts reports, pydantic models, dataclasses, lists, and
        dicts into plain python objects; None entries are dropped from
        containers, and anything unrecognized falls back to ``to_str``.
        """

        if isinstance(some_val, SupportsAsObj):
            return some_val.as_obj()
        elif isinstance(some_val, pydantic.BaseModel):
            return Report.to_pure_python_obj(some_val.model_dump())
        elif dataclasses.is_dataclass(some_val) and not isinstance(some_val, type):
            # The `is_dataclass` function returns `True` for both instances and classes.
            # We need an extra check to ensure an instance was passed in.
            # https://docs.python.org/3/library/dataclasses.html#dataclasses.is_dataclass
            return dataclasses.asdict(some_val)
        elif isinstance(some_val, list):
            return [Report.to_pure_python_obj(v) for v in some_val if v is not None]
        elif isinstance(some_val, timedelta):
            return humanfriendly.format_timespan(some_val)
        elif isinstance(some_val, datetime):
            try:
                return format_datetime_relative(some_val)
            except Exception:
                # we don't want to fail reporting because we were unable to pretty print a timestamp
                # Bug fix: this previously returned str(datetime) — the repr of
                # the datetime *class* — instead of the actual value.
                return str(some_val)
        elif isinstance(some_val, dict):
            return {
                Report.to_str(k): Report.to_pure_python_obj(v)
                for k, v in some_val.items()
                if v is not None
            }
        elif isinstance(some_val, (int, float, bool)):
            return some_val
        else:
            # fall through option
            return Report.to_str(some_val)

    def compute_stats(self) -> None:
        """A hook to compute derived stats"""
        pass

    def as_obj(self) -> dict:
        self.compute_stats()
        return {
            str(key): Report.to_pure_python_obj(value)
            for (key, value) in self.__dict__.items()
            # ignore nulls and fields starting with _
            if value is not None and not str(key).startswith("_")
        }

    def as_string(self) -> str:
        self_obj = self.as_obj()
        # aspects_by_subtypes is rendered separately as a table, so remove it
        # from the pprint-ed body.
        _aspects_by_subtypes = self_obj.pop("aspects_by_subtypes", None)

        # Format the main report data
        result = pprint.pformat(self_obj, width=150, sort_dicts=False)

        # Add aspects_by_subtypes table if it exists
        if _aspects_by_subtypes:
            result += "\n\nAspects by Subtypes:\n"
            result += self._format_aspects_by_subtypes_table(_aspects_by_subtypes)

        return result

    def _format_aspects_by_subtypes_table(
        self, aspects_by_subtypes: Dict[str, Dict[str, Dict[str, int]]]
    ) -> str:
        """Format aspects_by_subtypes data as a table with aspects as rows and entity/subtype as columns."""
        if not aspects_by_subtypes:
            return "No aspects by subtypes data available."

        all_aspects: set[str] = {
            aspect
            for subtypes in aspects_by_subtypes.values()
            for aspects in subtypes.values()
            for aspect in aspects
        }

        aspect_rows = sorted(all_aspects)

        # Bug fix: headers used to be sorted while the data columns followed
        # dict insertion order, so counts could land under the wrong header.
        # Build a single sorted list of (entity_type, subtype) keys and derive
        # both the headers and the data columns from it.
        column_keys = sorted(
            (
                (entity_type, subtype)
                for entity_type, subtypes in aspects_by_subtypes.items()
                for subtype in subtypes
            ),
            key=lambda pair: f"{pair[0]} ({pair[1]})",
        )

        headers = ["Aspect"] + [
            f"{entity_type} ({subtype})" for entity_type, subtype in column_keys
        ]

        table_data = [
            [aspect]
            + [
                aspects_by_subtypes[entity_type][subtype].get(aspect, 0)
                for entity_type, subtype in column_keys
            ]
            for aspect in aspect_rows
        ]

        if table_data:
            return tabulate(table_data, headers=headers, tablefmt="grid")
        else:
            return "No aspects by subtypes data available."

    def as_json(self) -> str:
        return json.dumps(self.as_obj())

    # TODO add helper method for warning / failure status + counts?


@dataclass
class SourceReportSubtypes:
    """One row of the file-backed urn table: an entity's type, subtype,
    per-aspect counts, and soft-deletion flag."""

    urn: str
    entity_type: str
    subType: str = "unknown"
    # Mutable default requires a factory so instances never share the dict.
    aspects: Dict[str, int] = field(default_factory=dict)
    soft_deleted: bool = False


class ReportAttribute(BaseModel):
    """Base for report entries that carry a severity and can log themselves."""

    severity: LogLevel = "DEBUG"
    help: Optional[str] = None

    @property
    def logger_sev(self) -> int:
        """Translate the severity literal into the stdlib logging level int."""
        # The LogLevel literals match the stdlib logging constant names exactly,
        # so a direct attribute lookup on the logging module suffices.
        return getattr(logging, self.severity)

    def log(self, msg: str) -> None:
        # stacklevel=3 attributes the log record to the caller of the report
        # helper rather than to this method.
        logger.log(level=self.logger_sev, msg=msg, stacklevel=3)


@dataclass
class ExamplesReport(Report, Closeable):
    """Report mixin that tracks per-entity aspect counts and sample urns.

    Raw per-urn observations are spooled into a SQLite-backed dict
    (``_file_based_dict``); :meth:`compute_stats` aggregates them into the
    ``aspects*`` counters and collects sample urns for profiling, usage, and
    lineage. Must be :meth:`close`-d to release the backing store.
    """

    aspects: Dict[str, Dict[str, int]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(int))
    )
    # This counts existence of aspects for each entity/subtype
    # This is used for the UI to calculate %age of entities with the aspect
    aspects_by_subtypes: Dict[str, Dict[str, Dict[str, int]]] = field(
        default_factory=lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(int))
        )
    )
    # This counts all aspects for each entity/subtype
    aspects_by_subtypes_full_count: Dict[str, Dict[str, Dict[str, int]]] = field(
        default_factory=lambda: defaultdict(
            lambda: defaultdict(lambda: defaultdict(int))
        )
    )
    samples: Dict[str, Dict[str, List[str]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(list))
    )
    compute_stats_time_seconds: float = 0.0
    _file_based_dict: Optional[FileBackedDict[SourceReportSubtypes]] = None

    # We are adding this to make querying easier for fine-grained lineage
    _fine_grained_lineage_special_case_name = "fineGrainedLineages"
    _samples_to_add: int = 20
    _lineage_aspects_seen: Set[str] = field(default_factory=set)

    def __post_init__(self) -> None:
        super().__post_init__()
        # Extra columns make the per-urn payload queryable via SQL below.
        self._file_based_dict = FileBackedDict(
            tablename="urn_aspects",
            extra_columns={
                "urn": lambda val: val.urn,
                "entityType": lambda val: val.entity_type,
                "subTypes": lambda val: val.subType,
                "aspects": lambda val: json.dumps(val.aspects),
                "soft_deleted": lambda val: val.soft_deleted,
            },
        )

    def close(self) -> None:
        """Finalize stats and release the SQLite-backed store."""
        self.compute_stats()
        if self._file_based_dict is not None:
            self._file_based_dict.close()
            self._file_based_dict = None

    def _build_aspects_where_clause(self, aspects: List[str]) -> str:
        """Build WHERE clause for matching any of the given aspects.

        NOTE: aspect names are interpolated directly into the LIKE pattern;
        they come from internal aspect-name constants, never user input.
        """
        if not aspects:
            return ""

        conditions = [f"aspects LIKE '%{aspect}%'" for aspect in aspects]
        return " OR ".join(conditions)

    def _collect_samples_by_subtype(self, where_clause: str, sample_key: str) -> None:
        """Helper method to collect samples organized by subtype for a given where clause."""

        subtype_query = f"""
        SELECT DISTINCT subTypes
        FROM urn_aspects 
        WHERE {where_clause}
        """
        assert self._file_based_dict is not None
        subtypes = set()
        for row in self._file_based_dict.sql_query(subtype_query):
            sub_type = row["subTypes"] or "unknown"
            subtypes.add(sub_type)

        # One bounded query per subtype so each subtype gets its own samples.
        for sub_type in subtypes:
            query = f"""
            SELECT urn
            FROM urn_aspects 
            WHERE {where_clause} AND subTypes = ?
            limit {self._samples_to_add}
            """

            for row in self._file_based_dict.sql_query(query, (sub_type,)):
                self.samples[sample_key][sub_type].append(row["urn"])

    def _collect_samples_by_aspects(self, aspects: List[str], sample_key: str) -> None:
        """Helper method to collect samples for entities that have any of the given aspects."""
        if not aspects:
            return

        where_clause = self._build_aspects_where_clause(aspects)
        self._collect_samples_by_subtype(where_clause, sample_key)

    def _collect_samples_by_lineage_aspects(
        self, aspects: List[str], sample_key: str
    ) -> None:
        """Helper method to collect samples for entities that have any of the given lineage aspects.

        Lineage aspects are stored in JSON format and require quote escaping in LIKE clauses.
        """
        if not aspects:
            return

        lineage_conditions = [f"aspects LIKE '%\"{aspect}\"%'" for aspect in aspects]
        where_clause = " OR ".join(lineage_conditions)
        self._collect_samples_by_subtype(where_clause, sample_key)

    def _collect_samples_with_all_conditions(self, sample_key: str) -> None:
        """
        Collect samples for entities that have lineage, profiling, and usage aspects.
        These specific 3 cases are added here as these URNs will be shown in the UI. Subject to change in future.
        """
        if not self._lineage_aspects_seen:
            return
        assert self._file_based_dict is not None

        # Build lineage conditions using the same logic as _collect_samples_by_lineage_aspects
        lineage_conditions = [
            f"aspects LIKE '%\"{aspect}\"%'" for aspect in self._lineage_aspects_seen
        ]
        lineage_where_clause = " OR ".join(lineage_conditions)

        # Build profiling conditions using the same logic as _collect_samples_by_aspects
        profiling_where_clause = self._build_aspects_where_clause(["datasetProfile"])

        # Build usage conditions using the same logic as _collect_samples_by_aspects
        usage_where_clause = self._build_aspects_where_clause(
            [
                "datasetUsageStatistics",
                "chartUsageStatistics",
                "dashboardUsageStatistics",
            ]
        )

        query = f"""
        SELECT urn, subTypes
        FROM urn_aspects
        WHERE ({lineage_where_clause})
        AND ({profiling_where_clause})
        AND ({usage_where_clause})
        limit {self._samples_to_add}
        """

        for row in self._file_based_dict.sql_query(query):
            sub_type = row["subTypes"] or "unknown"
            self.samples[sample_key][sub_type].append(row["urn"])

    def _has_fine_grained_lineage(
        self, mcp: Union[MetadataChangeProposalClass, MetadataChangeProposalWrapper]
    ) -> bool:
        """Return True if this mcp carries an UpstreamLineage aspect with fine-grained lineages."""
        # isinstance already narrows the type; the previous explicit cast was redundant.
        return isinstance(mcp.aspect, UpstreamLineageClass) and bool(
            mcp.aspect.fineGrainedLineages
        )

    def _update_file_based_dict(
        self,
        urn: str,
        entityType: str,
        aspectName: str,
        mcp: Union[MetadataChangeProposalClass, MetadataChangeProposalWrapper],
    ) -> None:
        """Record one observed aspect for `urn` into the backing store.

        Skips urns whose platform does not match this report's platform.
        """
        platform_name = guess_platform_name(urn)
        if platform_name != self.get_platform():
            return
        if is_lineage_aspect(entityType, aspectName):
            self._lineage_aspects_seen.add(aspectName)
        has_fine_grained_lineage = self._has_fine_grained_lineage(mcp)

        sub_type = "unknown"
        # Guard against an empty typeNames list (previously raised IndexError).
        if isinstance(mcp.aspect, SubTypesClass) and mcp.aspect.typeNames:
            sub_type = mcp.aspect.typeNames[0]

        assert self._file_based_dict is not None
        if urn in self._file_based_dict:
            if sub_type != "unknown":
                self._file_based_dict[urn].subType = sub_type
            aspects_dict = self._file_based_dict[urn].aspects
            aspects_dict[aspectName] = aspects_dict.get(aspectName, 0) + 1
            if has_fine_grained_lineage:
                special = self._fine_grained_lineage_special_case_name
                aspects_dict[special] = aspects_dict.get(special, 0) + 1
            self._file_based_dict.mark_dirty(urn)
        else:
            aspects_dict = {aspectName: 1}
            if has_fine_grained_lineage:
                aspects_dict[self._fine_grained_lineage_special_case_name] = 1
            self._file_based_dict[urn] = SourceReportSubtypes(
                urn=urn,
                entity_type=entityType,
                subType=sub_type,
                aspects=aspects_dict,
            )
        # isinstance() already guarantees mcp.aspect is not None; the previous
        # additional `mcp is not None and mcp.aspect is not None` checks were dead.
        if isinstance(mcp.aspect, StatusClass):
            self._file_based_dict[urn].soft_deleted = mcp.aspect.removed
            self._file_based_dict.mark_dirty(urn)

    def _store_workunit_data(self, wu: MetadataWorkUnit) -> None:
        """Fan a workunit out into mcps and record each aspect observation."""
        urn = wu.get_urn()

        if not isinstance(wu.metadata, MetadataChangeEvent):
            mcps = [wu.metadata]
        else:
            mcps = list(mcps_from_mce(wu.metadata))

        for mcp in mcps:
            entityType = mcp.entityType
            aspectName = mcp.aspectName

            if aspectName is None:
                continue

            self._update_file_based_dict(urn, entityType, aspectName, mcp)

    def compute_stats(self) -> None:
        """Aggregate the backing store into the aspect counters and samples."""
        start_time = datetime.now()
        if self._file_based_dict is None:
            return

        query = """
        SELECT entityType, subTypes, aspects, count(*) as count
        FROM urn_aspects 
        WHERE soft_deleted = 0
        GROUP BY entityType, subTypes, aspects
        """

        entity_subtype_aspect_counts: Dict[str, Dict[str, Dict[str, int]]] = (
            defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        )
        entity_subtype_aspect_counts_exist: Dict[str, Dict[str, Dict[str, int]]] = (
            defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
        )

        for row in self._file_based_dict.sql_query(query):
            entity_type = row["entityType"]
            sub_type = row["subTypes"]
            count = row["count"]
            # Bug fix: the aspects column holds a JSON *object* (see
            # __post_init__'s json.dumps of a dict), so the NULL fallback must
            # be "{}" — "[]" would decode to a list and crash on .items().
            aspects_raw = row["aspects"] or "{}"

            aspects = json.loads(aspects_raw)
            for aspect, aspect_count in aspects.items():
                entity_subtype_aspect_counts[entity_type][sub_type][aspect] += (
                    aspect_count * count
                )
                entity_subtype_aspect_counts_exist[entity_type][sub_type][aspect] += (
                    count
                )

        # Rebuild the public counters from scratch each time so repeated calls
        # (e.g. close() after an earlier compute_stats) don't double-count.
        self.aspects.clear()
        self.aspects_by_subtypes.clear()
        self.aspects_by_subtypes_full_count.clear()
        for entity_type, subtype_counts in entity_subtype_aspect_counts.items():
            for sub_type, aspect_counts in subtype_counts.items():
                for aspect, count in aspect_counts.items():
                    self.aspects[entity_type][aspect] += count
                self.aspects_by_subtypes_full_count[entity_type][sub_type] = dict(
                    aspect_counts
                )

        for entity_type, subtype_counts in entity_subtype_aspect_counts_exist.items():
            for sub_type, aspect_counts in subtype_counts.items():
                self.aspects_by_subtypes[entity_type][sub_type] = dict(aspect_counts)

        self.samples.clear()
        self._collect_samples_by_aspects(["datasetProfile"], "profiling")
        self._collect_samples_by_aspects(
            [
                "datasetUsageStatistics",
                "chartUsageStatistics",
                "dashboardUsageStatistics",
            ],
            "usage",
        )
        self._collect_samples_by_lineage_aspects(
            list(self._lineage_aspects_seen), "lineage"
        )
        self._collect_samples_with_all_conditions("all_3")
        end_time = datetime.now()
        self.compute_stats_time_seconds += (end_time - start_time).total_seconds()


class EntityFilterReport(ReportAttribute):
    """Tracks which entities of a given type were kept vs filtered out."""

    type: str

    processed_entities: LossyList[str] = pydantic.Field(default_factory=LossyList)
    dropped_entities: LossyList[str] = pydantic.Field(default_factory=LossyList)

    def processed(self, entity: str, type: Optional[str] = None) -> None:
        """Record an entity that passed the filter."""
        label = type or self.type
        self.log(f"Processed {label} {entity}")
        self.processed_entities.append(entity)

    def dropped(self, entity: str, type: Optional[str] = None) -> None:
        """Record an entity that was filtered out."""
        label = type or self.type
        self.log(f"Filtered {label} {entity}")
        self.dropped_entities.append(entity)

    def as_obj(self) -> dict:
        """Serialize both lossy lists for report output."""
        return {
            "filtered": self.dropped_entities.as_obj(),
            "processed": self.processed_entities.as_obj(),
        }

    @staticmethod
    def field(type: str, severity: LogLevel = "DEBUG") -> "EntityFilterReport":
        """A helper to create a dataclass field."""

        def _make() -> "EntityFilterReport":
            return EntityFilterReport(type=type, severity=severity)

        return dataclasses.field(default_factory=_make)
